{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "# 初始化tokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct')\n",
    "\n",
    "from utils import TokenTemplate\n",
    "             \n",
    "\n",
    "# 示例使用\n",
    "TEMPLATE = \"\"\"Here is a problem, a section of a article that may contain the answer to the problem and a previous memory. Please read carefully and update the memory based on given section to help answer the problem. \n",
    "You should keep the relevant information in the memory while adding new information.\n",
    "\n",
    "<problem> \n",
    "{problem}\n",
    "</problem>\n",
    "\n",
    "<memory>\n",
    "{memory}\n",
    "</memory>\n",
    "\n",
    "<section>\n",
    "{section}\n",
    "</section>\n",
    "\n",
    "Updated memory (should be enclosed in <memory> and </memory>)\n",
    "<memory>\n",
    "\"\"\"\n",
    "\n",
    "processor = TokenTemplate(TEMPLATE)\n",
    "processor.init(tokenizer)\n",
    "\n",
    "# 假设传入的 token ids（在实际使用时应该从模型或其他地方获取）\n",
    "kwarg_text = dict(\n",
    "    problem=\"What is the capital of France?\",\n",
    "    section=\"Here is a introduction to France. France is a country in Western Europe. Its capital is Paris.\",\n",
    "    memory=\"No previous memory\",\n",
    ")\n",
    "kwargs_token_ids = {\n",
    "    k: tokenizer.encode(v, add_special_tokens=False) for k, v in kwarg_text.items()\n",
    "}\n",
    "# 格式化模板\n",
    "formatted_template = processor.format(**kwargs_token_ids)\n",
    "print(tokenizer.decode(formatted_template))\n",
    "print(TEMPLATE.format(**kwarg_text) == tokenizer.decode(formatted_template))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer.__dict__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "message = tokenizer.apply_chat_template([{'role':'user','content':'{message}'}],add_generation_prompt=True, tokenize=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_substrings(lst, startlist, endlist):\n",
    "    start_len = len(startlist)\n",
    "    end_len = len(endlist)\n",
    "    substrings = []\n",
    "    \n",
    "    i = 0\n",
    "    while i <= len(lst) - start_len - end_len:\n",
    "        # 找到开始部分\n",
    "        if lst[i:i+start_len] == startlist:\n",
    "            j = i + start_len\n",
    "            # 找到结束部分\n",
    "            while j <= len(lst) - end_len and lst[j:j+end_len] != endlist:\n",
    "                j += 1\n",
    "            if j <= len(lst) - end_len and lst[j:j+end_len] == endlist:\n",
    "                substrings.append(lst[i+start_len:j])\n",
    "        i += 1\n",
    "    \n",
    "    return substrings\n",
    "\n",
    "# 示例用法\n",
    "lst = [1, 2, 3, 4, 5]\n",
    "startlist = [1, 2] # <summary>\n",
    "endlist = [4, 5] # </summary>\n",
    "print(find_substrings(lst, startlist, endlist))\n"
   ]
  }
 ],
 "metadata": {
  "fileId": "33614652-5eea-4934-a09e-ff6c624c013d",
  "filePath": "/opt/tiger/verl/recurrent/test.ipynb",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
